{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow 变量命名管理机制（二）\n",
    "\n",
    "\n",
    "### 1. 用 collection 来聚合变量\n",
    "前面介绍了 tensorflow 的变量命名机制，这里要补充一下 `tf.add_to_collection` 和 `tf.get_collection`的用法。\n",
    "\n",
    "因为神经网络中的参数非常多，有时候我们 只想对某些参数进行操作。除了前面通过变量名称的方法来获取参数之外， TensorFlow 中还有 collection 这么一种操作。\n",
    "\n",
    "collection 可以聚合多个**变量**或者**操作**。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')  # 不打印 warning \n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# 设置GPU按需增长\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "参考：[如何利用tf.add_to_collection、tf.get_collection以及tf.add_n来简化正则项的计算](https://blog.csdn.net/weixin_39980291/article/details/78352125)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vars in col1: [<tf.Variable 'v1:0' shape=(3,) dtype=int32_ref>, <tf.Variable 'v3:0' shape=(2, 3) dtype=float32_ref>]\n"
     ]
    }
   ],
   "source": [
    "# 把变量添加到一个 collection 中\n",
    "v1 = tf.Variable([1,2,3], name='v1')\n",
    "v2 = tf.Variable([2], name='v2')\n",
    "v3 = tf.get_variable(name='v3', shape=(2,3))\n",
    "\n",
    "tf.add_to_collection('col1', v1) # 把 v1 添加到 col1 中\n",
    "tf.add_to_collection('col1', v3)\n",
    "\n",
    "col1s = tf.get_collection(key='col1')  # 获取 col1 的变量\n",
    "print('vars in col1:', col1s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "除了把变量添加到集合中，还可以把操作添加到集合中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vars in col1: [<tf.Variable 'v1:0' shape=(3,) dtype=int32_ref>, <tf.Variable 'v3:0' shape=(2, 3) dtype=float32_ref>, <tf.Tensor 'add_op:0' shape=(3,) dtype=int32>]\n"
     ]
    }
   ],
   "source": [
    "op1 = tf.add(v1, 2, name='add_op')\n",
    "tf.add_to_collection('col1', op1)\n",
    "col1s = tf.get_collection(key='col1')  # 获取 col1 的变量\n",
    "print('vars in col1:', col1s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此外，还可以加上 scope 的约束。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vars in col1 with scope=model:  [<tf.Variable 'model/v5:0' shape=(3,) dtype=int32_ref>]\n"
     ]
    }
   ],
   "source": [
    "with tf.variable_scope('model'):\n",
    "    v4 = tf.get_variable('v4', shape=[3,4])\n",
    "    v5 = tf.Variable([1,2,3], name='v5')\n",
    "\n",
    "tf.add_to_collection('col1', v5)\n",
    "col1_vars = tf.get_collection(key='col1', scope='model')  # 获取 col1 的变量\n",
    "print('vars in col1 with scope=model: ', col1_vars)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. tf.GraphKeys \n",
    "\n",
    "参考：[tf.GraphKeys 函数](https://www.w3cschool.cn/tensorflow_python/tensorflow_python-ne7t2ezd.html)\n",
    "\n",
    "用于图形集合的标准名称。\n",
    "\n",
    "\n",
    "标准库使用各种已知的名称来收集和检索与图形相关联的值。例如，如果没有指定，则 tf.Optimizer 子类默认优化收集的变量tf.GraphKeys.TRAINABLE_VARIABLES，但也可以传递显式的变量列表。\n",
    "\n",
    "定义了以下标准键：\n",
    "\n",
    "- GLOBAL_VARIABLES：默认的 Variable 对象集合，在分布式环境共享（模型变量是其中的子集）。参考：tf.global_variables。通常，所有TRAINABLE_VARIABLES 变量都将在 MODEL_VARIABLES，所有 MODEL_VARIABLES 变量都将在 GLOBAL_VARIABLES。\n",
    "- LOCAL_VARIABLES：每台计算机的局部变量对象的子集。通常用于临时变量，如计数器。注意：使用 tf.contrib.framework.local_variable 添加到此集合。\n",
    "- MODEL_VARIABLES：在模型中用于推理（前馈）的变量对象的子集。注意：使用 tf.contrib.framework.model_variable 添加到此集合。\n",
    "- TRAINABLE_VARIABLES：将由优化器训练的变量对象的子集。\n",
    "- SUMMARIES：在关系图中创建的汇总张量对象。\n",
    "- QUEUE_RUNNERS：用于为计算生成输入的 QueueRunner 对象。\n",
    "- MOVING_AVERAGE_VARIABLES：变量对象的子集，它也将保持移动平均值。\n",
    "- REGULARIZATION_LOSSES：在图形构造期间收集的正规化损失。\n",
    "\n",
    "这个知道就好了，要用的时候知道是怎么回事就行了。比如在 BN 层中就有这东西。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
